// Package mssql contains the SQL review advisors for MSSQL databases.
package mssql

import (
	"fmt"
	"sort"

	"github.com/antlr4-go/antlr/v4"
	parser "github.com/bytebase/tsql-parser"
	"github.com/pkg/errors"

	"github.com/bytebase/bytebase/backend/plugin/advisor"
	tsqlparser "github.com/bytebase/bytebase/backend/plugin/parser/tsql"
	storepb "github.com/bytebase/bytebase/proto/generated-go/store"
)

// Compile-time assertion that *ColumnNoNullAdvisor satisfies advisor.Advisor.
var _ advisor.Advisor = (*ColumnNoNullAdvisor)(nil)

func init() {
	advisor.Register(storepb.Engine_MSSQL, advisor.MSSQLColumnNoNull, &ColumnNoNullAdvisor{})
}

// ColumnNoNullAdvisor is the SQL review advisor that reports MSSQL columns
// which are defined as nullable.
type ColumnNoNullAdvisor struct{}

// Check walks the ANTLR parse tree stored in ctx.AST with a columnNoNullChecker
// and returns the advices it collected. The string argument is ignored.
//
// It returns an error when ctx.AST is not an antlr.Tree, or when the rule level
// cannot be converted into an advice status.
func (*ColumnNoNullAdvisor) Check(ctx advisor.Context, _ string) ([]*storepb.Advice, error) {
	root, isTree := ctx.AST.(antlr.Tree)
	if !isTree {
		return nil, errors.Errorf("failed to convert to Tree")
	}

	status, err := advisor.NewStatusBySQLReviewRuleLevel(ctx.Rule.Level)
	if err != nil {
		return nil, err
	}

	// currentNormalizedTableName is left at its zero value ("").
	checker := &columnNoNullChecker{
		level:                            status,
		title:                            string(ctx.Rule.Type),
		isCurrentTableColumnNullable:     map[string]bool{},
		currentTableColumnIsNullableLine: map[string]int{},
		adviceList:                       []*storepb.Advice{},
	}
	antlr.ParseTreeWalkerDefault.Walk(checker, root)
	return checker.generateAdvice()
}

// columnNoNullChecker is the parse-tree listener used by ColumnNoNullAdvisor.
// While walking a CREATE TABLE or ALTER TABLE statement it records which
// columns are nullable and the line each was defined on. When the statement is
// exited, the nullable ones are turned into advices.
type columnNoNullChecker struct {
	*parser.BaseTSqlParserListener

	// Metadata copied onto every advice produced by this checker.
	title string
	level storepb.Advice_Status

	// Per-statement state. The two maps are replaced with fresh ones when a
	// CREATE TABLE or ALTER TABLE is exited.

	// currentNormalizedTableName is the normalized name of the table currently
	// being created or altered.
	// NOTE(review): it is assigned and cleared, but no method in this file reads it.
	currentNormalizedTableName string
	// currentTableColumnIsNullableLine maps a column name, as returned by
	// tsqlparser.NormalizeTSQLIdentifier, to the line where it is defined.
	currentTableColumnIsNullableLine map[string]int
	// isCurrentTableColumnNullable maps a column name, as returned by
	// tsqlparser.NormalizeTSQLIdentifier, to whether it is considered nullable.
	isCurrentTableColumnNullable map[string]bool

	// adviceList accumulates advices across all statements in the walk.
	adviceList []*storepb.Advice
}

// generateAdvice returns the advices collected during the walk.
//
// The returned slice is empty (but non-nil when the checker was built by
// Check) if no nullable column was found. The error is always nil.
func (l *columnNoNullChecker) generateAdvice() ([]*storepb.Advice, error) {
	return l.adviceList, nil
}

// EnterCreate_table is called when production create_table is entered.
// When the statement names a table, it stores that table's normalized name,
// passing "dbo" as the fallback schema and no fallback database.
func (l *columnNoNullChecker) EnterCreate_table(ctx *parser.Create_tableContext) {
	if name := ctx.Table_name(); name != nil {
		l.currentNormalizedTableName = tsqlparser.NormalizeTSQLTableName(name, "" /* fallbackDatabase */, "dbo" /* fallbackSchema */, false /* caseSensitive */)
	}
}

// ExitCreate_table is called when production create_table is exited.
//
// It emits one advice for every column still marked as nullable, then resets
// the per-table state. Advices are ordered by the line of the column
// definition, with the column name as a tie-breaker. This keeps the output
// independent of Go's unspecified map iteration order.
func (l *columnNoNullChecker) ExitCreate_table(_ *parser.Create_tableContext) {
	l.currentNormalizedTableName = ""

	nullableColumns := make([]string, 0, len(l.isCurrentTableColumnNullable))
	for columnName, isNullable := range l.isCurrentTableColumnNullable {
		if isNullable {
			nullableColumns = append(nullableColumns, columnName)
		}
	}
	sort.Slice(nullableColumns, func(i, j int) bool {
		lineI := l.currentTableColumnIsNullableLine[nullableColumns[i]]
		lineJ := l.currentTableColumnIsNullableLine[nullableColumns[j]]
		if lineI != lineJ {
			return lineI < lineJ
		}
		return nullableColumns[i] < nullableColumns[j]
	})

	for _, columnName := range nullableColumns {
		l.adviceList = append(l.adviceList, &storepb.Advice{
			Status:  l.level,
			Code:    advisor.ColumnCannotNull.Int32(),
			Title:   l.title,
			Content: fmt.Sprintf("Column [%s] is nullable, which is not allowed.", columnName),
			StartPosition: &storepb.Position{
				Line: int32(l.currentTableColumnIsNullableLine[columnName]),
			},
		})
	}

	l.isCurrentTableColumnNullable = make(map[string]bool)
	l.currentTableColumnIsNullableLine = make(map[string]int)
}

// EnterTable_constraint is called when production table_constraint is entered.
//
// It only handles constraints whose parent is column_def_table_constraint,
// i.e. constraints listed alongside column definitions. If such a constraint
// is a PRIMARY KEY, every column it lists is marked as non-nullable, because
// SQL Server makes primary key columns NOT NULL.
//
// NOTE(review): if the constraint appears before a column's own definition,
// EnterColumn_definition later overwrites this verdict for that column.
func (l *columnNoNullChecker) EnterTable_constraint(ctx *parser.Table_constraintContext) {
	if _, ok := ctx.GetParent().(*parser.Column_def_table_constraintContext); !ok {
		return
	}
	if ctx.PRIMARY() == nil {
		return
	}
	// Guard against a missing column list rather than panicking on a nil interface.
	columnList := ctx.Column_name_list_with_order()
	if columnList == nil {
		return
	}
	for _, column := range columnList.AllId_() {
		_, columnName := tsqlparser.NormalizeTSQLIdentifier(column)
		l.isCurrentTableColumnNullable[columnName] = false
	}
}

// EnterColumn_definition is called when production column_definition is entered.
//
// It handles columns defined directly under ALTER TABLE (ALTER COLUMN) or in a
// column/constraint list (CREATE TABLE, ALTER TABLE ... ADD). It records the
// column's definition line and whether the column is nullable.
//
// A column is treated as non-nullable only if one of its column constraints is
// PRIMARY KEY or an explicit NOT NULL. Other column constraints (UNIQUE, CHECK,
// REFERENCES/FOREIGN KEY) do not change nullability. Previously they wrongly
// suppressed the advice because they carry no null_notnull child.
func (l *columnNoNullChecker) EnterColumn_definition(ctx *parser.Column_definitionContext) {
	switch ctx.GetParent().(type) {
	case *parser.Alter_tableContext, *parser.Column_def_table_constraintContext:
	default:
		return
	}

	columnID := ctx.Id_()
	_, columnName := tsqlparser.NormalizeTSQLIdentifier(columnID)

	nullable := true
	for _, element := range ctx.AllColumn_definition_element() {
		constraint := element.Column_constraint()
		if constraint == nil {
			continue
		}
		// SQL Server makes PRIMARY KEY columns NOT NULL.
		if constraint.PRIMARY() != nil {
			nullable = false
			break
		}
		// An explicit NOT NULL makes the column non-nullable. An explicit NULL,
		// or a constraint without a null_notnull part, leaves it nullable.
		if nullNotNull := constraint.Null_notnull(); nullNotNull != nil && nullNotNull.NOT() != nil {
			nullable = false
			break
		}
	}

	l.isCurrentTableColumnNullable[columnName] = nullable
	l.currentTableColumnIsNullableLine[columnName] = columnID.GetStart().GetLine()
}

// EnterAlter_table is called when production alter_table is entered.
//
// It stores the normalized name of the first table in the statement only for
// the forms this checker cares about:
//   - ALTER TABLE ... ALTER COLUMN ...: two ALTER tokens plus COLUMN.
//   - ALTER TABLE ... ADD ...: one ALTER token, ADD, and no WITH keyword.
func (l *columnNoNullChecker) EnterAlter_table(ctx *parser.Alter_tableContext) {
	tableName := ctx.Table_name(0)
	if tableName == nil {
		return
	}

	alterTokens := len(ctx.AllALTER())
	isAlterColumn := alterTokens == 2 && ctx.COLUMN() != nil
	isAdd := alterTokens == 1 && ctx.ADD() != nil && ctx.WITH() == nil
	if !isAlterColumn && !isAdd {
		return
	}

	l.currentNormalizedTableName = tsqlparser.NormalizeTSQLTableName(tableName, "" /* fallbackDatabase */, "dbo" /* fallbackSchema */, false /* caseSensitive */)
}

// ExitAlter_table is called when production alter_table is exited.
//
// It emits one advice for every column still marked as nullable, then resets
// the per-table state. Advices are ordered by the line of the column
// definition, with the column name as a tie-breaker. This keeps the output
// independent of Go's unspecified map iteration order.
func (l *columnNoNullChecker) ExitAlter_table(_ *parser.Alter_tableContext) {
	l.currentNormalizedTableName = ""

	nullableColumns := make([]string, 0, len(l.isCurrentTableColumnNullable))
	for columnName, isNullable := range l.isCurrentTableColumnNullable {
		if isNullable {
			nullableColumns = append(nullableColumns, columnName)
		}
	}
	sort.Slice(nullableColumns, func(i, j int) bool {
		lineI := l.currentTableColumnIsNullableLine[nullableColumns[i]]
		lineJ := l.currentTableColumnIsNullableLine[nullableColumns[j]]
		if lineI != lineJ {
			return lineI < lineJ
		}
		return nullableColumns[i] < nullableColumns[j]
	})

	for _, columnName := range nullableColumns {
		l.adviceList = append(l.adviceList, &storepb.Advice{
			Status:  l.level,
			Code:    advisor.ColumnCannotNull.Int32(),
			Title:   l.title,
			Content: fmt.Sprintf("Column [%s] is nullable, which is not allowed.", columnName),
			StartPosition: &storepb.Position{
				Line: int32(l.currentTableColumnIsNullableLine[columnName]),
			},
		})
	}

	l.isCurrentTableColumnNullable = make(map[string]bool)
	l.currentTableColumnIsNullableLine = make(map[string]int)
}
